import time
import torch
from utils.model import Enhanced_large_GCNv3,Simplified_GCN
import utils.config as config
from generate_data.load_data_h5py import reload_config
import torch.nn as nn
import matplotlib.pyplot as plt
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch_geometric.data import DataLoader
import torch.optim as optim
import gc
import sys


def setup(rank, world_size):
    """Initialise the default NCCL process group for this process.

    The rendezvous endpoint is taken from the ``MASTER_ADDR`` / ``MASTER_PORT``
    environment variables. They are only filled in with the single-node
    defaults below when the environment does not already define them, so an
    endpoint exported by a launcher or by the user is respected instead of
    being silently overwritten.

    Args:
        rank: Index of this process within the process group.
        world_size: Total number of processes in the group.
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')  # master node address
    os.environ.setdefault('MASTER_PORT', '12356')      # master node port
    # NOTE: this function does not select a CUDA device; the caller is
    # expected to have done so before the process group is created.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Tear down the default process group if one is active.

    Safe to call when no group was ever initialised (or after it has already
    been destroyed), which makes it usable from ``finally`` blocks and other
    teardown paths without masking the original error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def _save_curve(values, label, color, marker, title, ylabel, filename):
    """Plot one per-epoch metric series and write the figure to ``filename``."""
    plt.figure(figsize=(8, 6))
    plt.plot(values, label=label, color=color, linestyle='-', marker=marker)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.legend()
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()


def _run_training(rank, world_size):
    """Build the model, run the epoch loop and write checkpoints and plots.

    Assumes the process group is already initialised and the CUDA device for
    ``rank`` is already selected; ``train`` takes care of both.
    """
    # Build the model selected by config.train_type and pick its save path.
    if config.train_type == 'std':
        model = Simplified_GCN(num_node_features=9, hidden_dim=32, output_dim=2*config.Nbus-1).to(rank)
        save_path = config.path
    else:
        print(f'output_dim: {2*config.Nbus-1}')
        model = Enhanced_large_GCNv3(num_node_features=9, hidden_dim=128, output_dim=2*config.Nbus-1).to(rank)
        save_path = config.pre_path

    print(f"[Rank {rank}] Model and data are on: {next(model.parameters()).device}")

    model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(model.parameters(), lr=config.lr)

    # presumably reload_config() returns a PyG dataset whose items are
    # (graph, target) pairs -- that is how batches are unpacked below.
    train_dataset = reload_config()
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        sampler=train_sampler,
        pin_memory=True,
    )

    criterion = nn.MSELoss()
    epochs = config.epochs

    # Epochs with index <= warmup_epochs are neither recorded for the plots
    # nor considered for checkpointing.
    warmup_epochs = 30
    least_loss = 1e8
    loss_history = []
    mse_history = []

    total_start_time = time.time()

    for epoch in range(epochs):
        # Re-seed the sampler so every epoch uses a different shuffle.
        train_sampler.set_epoch(epoch)
        model.train()
        running_loss = 0.0
        running_mse = 0.0

        for data in train_loader:
            graph, voltage = data
            graph = graph.to(rank)
            voltage = voltage.to(rank)

            # Exactly one forward pass per step. The model returns a pair and
            # only the second element is used as the prediction.
            _, output_y = model(graph)

            mse = criterion(output_y, voltage)
            loss = config.mse_weight * mse

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() synchronises with the device, so no explicit
            # torch.cuda.synchronize() is needed in the loop.
            running_loss += loss.item()
            running_mse += mse.item()

            # Drop references so the end-of-epoch cache release below can
            # actually return the last batch's memory.
            del graph, voltage, output_y, loss, mse

        gc.collect()
        torch.cuda.empty_cache()

        # TODO: the validation call (test_model) is not wired in; during the
        # last 50 epochs, every epoch whose index is a multiple of 10 only
        # emits this log line on rank 0.
        if epoch >= epochs - 50 and epoch % 10 == 0 and rank == 0:
            print(f"Epoch {epoch}: Running test_model for validation...", flush=True)

        # Logging, history and checkpointing happen on rank 0 only. The
        # reported values are sums over this rank's batches, not means.
        if rank == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {running_loss:.4f}, MSE: {running_mse:.4f}',flush=True)
            if epoch > warmup_epochs:
                loss_history.append(running_loss)
                mse_history.append(running_mse)
                if running_loss < least_loss:
                    # Save the unwrapped module so the checkpoint loads
                    # without the DDP "module." key prefix.
                    torch.save(model.module.state_dict(), save_path)
                    least_loss = running_loss

    total_training_time = time.time() - total_start_time
    # Guard against a zero-epoch configuration.
    avg_epoch_time = total_training_time / epochs if epochs > 0 else 0.0

    if rank == 0:
        print(
            f'Total training time: {total_training_time:.2f}s, '
            f'average per epoch: {avg_epoch_time:.2f}s',
            flush=True,
        )
        _save_curve(loss_history, 'Loss', 'blue', 'o',
                    'Training Loss Curve', 'Loss', 'Fig_loss.png')
        _save_curve(mse_history, 'MSE', 'red', 'x',
                    'Training MSE Curve', 'MSE', 'Fig_mse.png')


def train(rank, world_size):
    """Run DDP training on the GPU with index ``rank``.

    Selects the CUDA device, initialises the process group, trains the model
    chosen by ``config.train_type`` and, on rank 0, writes the best checkpoint
    (lowest summed epoch loss after epoch index 30) and the loss / MSE curves
    to ``Fig_loss.png`` and ``Fig_mse.png``. The process group is destroyed
    even when training raises.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of participating processes.
    """
    torch.cuda.set_device(rank)
    # Allocate a tiny tensor up front to force CUDA initialisation on this
    # device before the process group is created.
    torch.tensor([0.0], device=f'cuda:{rank}')

    print(f"Running DDP on rank {rank}.",flush=True)
    setup(rank, world_size)
    try:
        _run_training(rank, world_size)
    finally:
        # Only reached after setup() succeeded, so the group exists here.
        cleanup()